#include <errno.h>
#include <limits.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>

#include <arpa/inet.h>
#include <getopt.h>
#include <sys/socket.h>
#include <unistd.h>

/* Size of the receive buffer. Per POSIX recvfrom(), the excess bytes of a
 * longer SOCK_DGRAM message are discarded, so such datagrams are forwarded
 * truncated to this length. */
#define BUFFER_SIZE 2048

/* Upper bound on the number of backends accepted via -b. */
enum { MAX_BACKENDS = 10 };

/* Seconds between two statistics reports. */
enum { STATS_INTERVAL_SEC = 10 };

/*
 * Parse a base-10 integer in [min, max] from the whole of `s`.
 * Returns 0 and stores the value in *out on success; returns -1 on empty,
 * malformed, out-of-range or trailing-garbage input (atoi would silently
 * yield 0 or a truncated value instead).
 */
static int parse_int(const char *s, long min, long max, int *out)
{
  char *end;
  long value;

  errno = 0;
  value = strtol(s, &end, 10);
  if (end == s || *end != '\0' || errno == ERANGE || value < min || value > max) {
    return -1;
  }
  *out = (int)value;
  return 0;
}

/*
 * Parse an "IPv4:port" backend spec into *addr (zeroed first).
 * Returns 0 on success, -1 if the colon is missing or the address/port is invalid.
 */
static int parse_backend(const char *spec, struct sockaddr_in *addr)
{
  const char *colon = strchr(spec, ':');
  char ip[INET_ADDRSTRLEN];
  size_t ip_len;
  int port;

  if (colon == NULL) {
    return -1;
  }
  ip_len = (size_t)(colon - spec);
  if (ip_len == 0 || ip_len >= sizeof(ip)) {
    return -1;
  }
  memcpy(ip, spec, ip_len);
  ip[ip_len] = '\0';

  if (parse_int(colon + 1, 1, 65535, &port) != 0) {
    return -1;
  }

  memset(addr, 0, sizeof(*addr));
  addr->sin_family = AF_INET;
  addr->sin_port = htons(port);
  return inet_pton(AF_INET, ip, &addr->sin_addr) == 1 ? 0 : -1;
}

/*
 * Print one statistics report. time_t has no printf conversion of its own,
 * so it is cast to long long and printed with %lld.
 */
static void print_stats(time_t now, int received, int errors,
                        const int *forwarded, int num_backends)
{
  int i;

  printf("%lld: Received packets: %d Error packets %d\n", (long long)now, received, errors);
  for (i = 0; i < num_backends; i++) {
    printf("%lld: Forwarded packets to backend %d: %d\n", (long long)now, i, forwarded[i]);
  }
  fflush(stdout);
}

/* Print the usage line to stderr and terminate with EXIT_FAILURE. */
static void usage(const char *prog)
{
  fprintf(stderr, "Usage: %s -l <listen_port> -a <listen_ip> -b <backend_ip:port> -s <sample_rate>\n", prog);
  exit(EXIT_FAILURE);
}

/*
 * UDP load balancer: receives datagrams on listen_ip:listen_port and forwards
 * every sample_rate-th one to the configured backends in round-robin order.
 *
 * Options:
 *   -l <port>           listen port (default 12201; NOTE(review): this is the
 *                       conventional Graylog GELF UDP port, presumably the
 *                       intended traffic — confirm)
 *   -a <ip>             IPv4 listen address (default 0.0.0.0)
 *   -b <ip:port,...>    comma-separated backends; may be repeated; at most
 *                       MAX_BACKENDS in total, at least one required
 *   -s <n>              forward one of every n received packets (n >= 1)
 *
 * Every STATS_INTERVAL_SEC seconds (checked only when a packet arrives, since
 * recvfrom blocks) it prints and resets the received, malformed and
 * per-backend forwarded counters. The loop only ends on a fatal recvfrom
 * error, so the process then exits with EXIT_FAILURE.
 */
int main(int argc, char *argv[]) {
  int listen_sock;
  int *backend_socks;
  struct sockaddr_in listen_addr, client_addr;
  struct sockaddr_in *backend_addrs;
  socklen_t addr_len;
  char buffer[BUFFER_SIZE];
  int backend_index = 0;
  int listen_port = 12201;
  const char *listen_ip = "0.0.0.0";
  char *backend_ips_ports[MAX_BACKENDS];
  int num_backends = 0;
  int pkt_count = 0, sample_rate = 1;

  int received_count = 0;
  int forwarded_error_count = 0;
  int *forwarded_count;
  time_t start_time, current_time;

  int i, opt;
  while ((opt = getopt(argc, argv, "l:a:b:s:")) != -1) {
    switch (opt) {
      case 'l':
        if (parse_int(optarg, 1, 65535, &listen_port) != 0) {
          fprintf(stderr, "Invalid listen port: %s\n", optarg);
          usage(argv[0]);
        }
        break;
      case 'a':
        listen_ip = optarg;
        break;
      case 'b': {
        /* strtok splits optarg in place; the tokens point into argv. */
        char *token = strtok(optarg, ",");
        while (token != NULL) {
          if (num_backends >= MAX_BACKENDS) {
            fprintf(stderr, "Too many backends (max %d)\n", MAX_BACKENDS);
            exit(EXIT_FAILURE);
          }
          backend_ips_ports[num_backends++] = token;
          token = strtok(NULL, ",");
        }
        break;
      }
      case 's':
        if (parse_int(optarg, 1, INT_MAX, &sample_rate) != 0) {
          fprintf(stderr, "Invalid sample rate: %s\n", optarg);
          usage(argv[0]);
        }
        break;
      default:
        usage(argv[0]);
    }
  }

  /* The round-robin step divides by num_backends, so at least one is required. */
  if (num_backends == 0) {
    fprintf(stderr, "At least one backend (-b) is required\n");
    usage(argv[0]);
  }

  backend_socks = calloc((size_t)num_backends, sizeof(*backend_socks));
  backend_addrs = calloc((size_t)num_backends, sizeof(*backend_addrs));
  forwarded_count = calloc((size_t)num_backends, sizeof(*forwarded_count));
  if (backend_socks == NULL || backend_addrs == NULL || forwarded_count == NULL) {
    perror("calloc failed");
    exit(EXIT_FAILURE);
  }

  memset(&listen_addr, 0, sizeof(listen_addr));
  listen_addr.sin_family = AF_INET;
  listen_addr.sin_port = htons(listen_port);
  /* inet_pton reports bad input; inet_addr would silently yield 255.255.255.255. */
  if (inet_pton(AF_INET, listen_ip, &listen_addr.sin_addr) != 1) {
    fprintf(stderr, "Invalid listen address: %s\n", listen_ip);
    exit(EXIT_FAILURE);
  }

  /* Create the listening socket. */
  if ((listen_sock = socket(AF_INET, SOCK_DGRAM, 0)) < 0) {
    perror("socket creation failed");
    exit(EXIT_FAILURE);
  }

  /* Bind the listening socket to the requested address and port. */
  if (bind(listen_sock, (const struct sockaddr *)&listen_addr, sizeof(listen_addr)) < 0) {
    perror("bind failed");
    close(listen_sock);
    exit(EXIT_FAILURE);
  }

  printf("UDP Load Balancer is listening on %s:%d\n", listen_ip, listen_port);

  /* Pre-create one sending socket per backend. */
  for (i = 0; i < num_backends; i++) {
    char ip_text[INET_ADDRSTRLEN];

    if (parse_backend(backend_ips_ports[i], &backend_addrs[i]) != 0) {
      fprintf(stderr, "Invalid backend format: %s\n", backend_ips_ports[i]);
      exit(EXIT_FAILURE);
    }
    if ((backend_socks[i] = socket(AF_INET, SOCK_DGRAM, 0)) < 0) {
      perror("backend socket creation failed");
      exit(EXIT_FAILURE);
    }
    inet_ntop(AF_INET, &backend_addrs[i].sin_addr, ip_text, sizeof(ip_text));
    printf("Create socket for backend %s:%d\n", ip_text, ntohs(backend_addrs[i].sin_port));
  }

  time(&start_time);

  for (;;) {
    ssize_t n;

    /* addr_len is value-result, so it must be reset before every call. */
    addr_len = sizeof(client_addr);
    n = recvfrom(listen_sock, buffer, sizeof(buffer), 0, (struct sockaddr *)&client_addr, &addr_len);
    if (n < 0) {
      if (errno == EINTR) {
        continue; /* interrupted by a signal, not a socket failure */
      }
      perror("recvfrom failed");
      break;
    }
    received_count++;

    /* Forward only every sample_rate-th received packet. */
    if (++pkt_count >= sample_rate) {
      pkt_count = 0;

      /* A packet not ending in '}' (an empty one included) is counted as
       * malformed but is still forwarded. n == 0 is tested first so that
       * buffer[n - 1] is never read out of bounds. */
      if (n == 0 || buffer[n - 1] != '}') {
        forwarded_error_count++;
      }

      if (sendto(backend_socks[backend_index], buffer, (size_t)n, 0,
                 (const struct sockaddr *)&backend_addrs[backend_index],
                 sizeof(backend_addrs[backend_index])) < 0) {
        /* Report and move on: one failed send must not stop the balancer. */
        perror("sendto failed");
      } else {
        forwarded_count[backend_index]++;
      }

      /* Round-robin to the next backend. */
      backend_index = (backend_index + 1) % num_backends;
    }

    /* Report and reset the counters once STATS_INTERVAL_SEC have elapsed.
     * This is checked for every received packet, so reports continue while
     * packets are being sampled out. */
    time(&current_time);
    if (difftime(current_time, start_time) >= STATS_INTERVAL_SEC) {
      print_stats(current_time, received_count, forwarded_error_count, forwarded_count, num_backends);
      memset(forwarded_count, 0, (size_t)num_backends * sizeof(*forwarded_count));
      received_count = 0;
      forwarded_error_count = 0;
      start_time = current_time;
    }
  }

  close(listen_sock);
  for (i = 0; i < num_backends; i++) {
    close(backend_socks[i]);
  }
  free(backend_socks);
  free(backend_addrs);
  free(forwarded_count);
  /* The loop above only exits on a fatal recvfrom error. */
  return EXIT_FAILURE;
}

